{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example: Improving Math Reasoning with a Custom Recipe for Automated Prompt Engineering (DSPy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "from typing import Dict, Optional\n",
    "\n",
    "import altair as alt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from dataset.dataset import get_examples, is_correct\n",
    "from scipy.stats import norm\n",
    "from tensorzero import AsyncTensorZeroGateway, InferenceResponse\n",
    "from tqdm.asyncio import tqdm_asyncio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up the TensorZero Client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t0 = await AsyncTensorZeroGateway.build_http(gateway_url=\"http://localhost:3000\", timeout=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_examples = get_examples(\"train\")\n",
    "random.shuffle(train_examples)\n",
    "\n",
    "test_examples = get_examples(\"test\")\n",
    "random.shuffle(test_examples)\n",
    "\n",
    "print(train_examples[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def solve_math_problem(\n",
    "    question: str,\n",
    "    *,\n",
    "    variant_name: Optional[str] = None,\n",
    "    dryrun: bool = False,\n",
    ") -> Optional[InferenceResponse]:\n",
    "    try:\n",
    "        return await t0.inference(\n",
    "            function_name=\"solve_math_problem\",\n",
    "            input={\"messages\": [{\"role\": \"user\", \"content\": {\"question\": question}}]},\n",
    "            cache_options={\"enabled\": \"on\"},\n",
    "            variant_name=variant_name,\n",
    "            dryrun=dryrun,\n",
    "        )\n",
    "    except Exception as e:\n",
    "        print(f\"Error ({type(e)}): {e}\")\n",
    "        return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the function below, the only feedback provided to TensorZero is whether the output of the function is correct.\n",
    "We do not provide the correct answer in cases of mistakes.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def solve_and_score_math_problem(\n",
    "    example: Dict[str, str],\n",
    "    *,\n",
    "    variant_name: Optional[str] = None,\n",
    "    dryrun: bool = False,\n",
    ") -> Optional[bool]:\n",
    "    # Run the TensorZero function on the example\n",
    "    response = await solve_math_problem(example[\"question\"], variant_name=variant_name, dryrun=dryrun)\n",
    "\n",
    "    # Inference failed completely, disregard this example\n",
    "    if response is None:\n",
    "        return None\n",
    "\n",
    "    # Inference succeeded, but the first block is not text, so we consider it incorrect\n",
    "    first_block = response.content[0]\n",
    "    if first_block.type != \"text\":\n",
    "        return False\n",
    "\n",
    "    # Inference succeeded, and the first block is text, so we score the example\n",
    "    correct = is_correct(first_block.text, example)\n",
    "\n",
    "    # Store the feedback in the database\n",
    "    # (skipped on dry runs: those inferences are not stored, so sending feedback for them would trigger an error)\n",
    "    if not dryrun:\n",
    "        await t0.feedback(\n",
    "            metric_name=\"correct\",\n",
    "            value=correct,\n",
    "            inference_id=response.inference_id,\n",
    "            dryrun=dryrun,\n",
    "        )\n",
    "\n",
    "    return correct"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Examples\n",
    "\n",
    "Run the TensorZero function on the training examples, grade the answers, and store the feedback in the database.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cap the number of concurrent inferences; lower this value if you hit rate limits\n",
    "MAX_CONCURRENT_INFERENCES = 100\n",
    "\n",
    "semaphore = asyncio.Semaphore(MAX_CONCURRENT_INFERENCES)\n",
    "\n",
    "\n",
    "async def run_inference(\n",
    "    example: Dict[str, str], *, variant_name: Optional[str] = None, dryrun: bool = False\n",
    ") -> Optional[bool]:\n",
    "    async with semaphore:\n",
    "        return await solve_and_score_math_problem(example, variant_name=variant_name, dryrun=dryrun)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_TRAINING_INFERENCES = 1000\n",
    "\n",
    "coroutines = [run_inference(example) for example in train_examples[:NUM_TRAINING_INFERENCES]]\n",
    "\n",
    "results = await tqdm_asyncio.gather(*coroutines, desc=\"Running training inferences\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the cell below, we evaluate the accuracy of each variant on a subset of the test examples. If you generate a new variant, add its name to `VARIANTS` and rerun this cell to evaluate it.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_TEST_INFERENCES = 200\n",
    "\n",
    "VARIANTS = [\n",
    "    \"gpt_4o_mini_baseline\",\n",
    "    # The provided optimized variant was generated by DSPy using the code below\n",
    "    # \"gpt_4o_mini_optimized\",\n",
    "]\n",
    "\n",
    "accuracies = {}\n",
    "\n",
    "for variant_name in VARIANTS:\n",
    "    # We use dryrun=True here to avoid leaking the test set into the database\n",
    "    coroutines = [\n",
    "        run_inference(example, variant_name=variant_name, dryrun=True)\n",
    "        for example in test_examples[:NUM_TEST_INFERENCES]\n",
    "    ]\n",
    "\n",
    "    results = await tqdm_asyncio.gather(*coroutines, desc=\"Running test inferences\")\n",
    "\n",
    "    # Drop failed inferences (None) and report the fraction that succeeded\n",
    "    total_results = len(results)\n",
    "    results = [result for result in results if result is not None]\n",
    "    if not results:\n",
    "        # Nothing to score, so skip this variant instead of dividing by zero\n",
    "        print(f\"{variant_name}: every inference failed, skipping\")\n",
    "        continue\n",
    "    success_rate = len(results) / total_results\n",
    "    print(f\"Success rate: {success_rate:.1%}\")\n",
    "\n",
    "    accuracy = sum(results) / len(results)\n",
    "    n = len(results)\n",
    "    z = norm.ppf(0.975)  # 95% confidence interval\n",
    "    margin_of_error = z * np.sqrt((accuracy * (1 - accuracy)) / n)\n",
    "\n",
    "    print(f\"Accuracy: {accuracy:.4f}\")\n",
    "    print(f\"95% Confidence Interval: ({accuracy - margin_of_error:.4f}, {accuracy + margin_of_error:.4f})\")\n",
    "\n",
    "    accuracies[variant_name] = accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_df = pd.DataFrame(\n",
    "    [{\"Variant\": variant_name, \"Score\": accuracy} for variant_name, accuracy in accuracies.items()]\n",
    ")\n",
    "\n",
    "scores_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chart = (\n",
    "    alt.Chart(scores_df)\n",
    "    .encode(\n",
    "        x=alt.X(\"Score:Q\", axis=alt.Axis(format=\"%\"), scale=alt.Scale(domain=[0, 1])),\n",
    "        y=\"Variant:N\",\n",
    "        text=alt.Text(\"Score:Q\", format=\".1%\"),\n",
    "    )\n",
    "    .properties(title=\"Metrics by Variant\")\n",
    ")\n",
    "\n",
    "chart = chart.mark_bar() + chart.mark_text(align=\"left\", dx=2)\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, we could run any TensorZero recipe on this historical data to generate a new variant that might perform better. Give it a try!\n",
    "\n",
    "Below, we include an example of how to use an external library, [DSPy](https://dspy-docs.vercel.app/), to automatically optimize a prompt for this function.\n",
    "Because TensorZero stores inference and feedback data in ClickHouse, where it can be queried straight into pandas DataFrames, it is easy to integrate with nearly any ML library yourself.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automated Prompt Engineering with DSPy\n",
    "\n",
    "The rest of this notebook shows how to pull data from the TensorZero data model in ClickHouse and use it to optimize a function's prompt with DSPy.\n",
    "Because DSPy offers many prompt optimization strategies, the same code skeleton can be reused to try different ones.\n",
    "However, a few things (table name, feedback name, chat function type, etc.) are set specifically for this example.\n",
    "You can change them to fit your needs.\n",
    "At a high level, the rest of the notebook does the following:\n",
    "\n",
    "1. Pull data from ClickHouse and convert it into a DSPy dataset.\n",
    "2. Run a prompt optimization loop using one of the teleprompting classes supported by DSPy.\n",
    "3. Print the optimized prompt from the history so that you can write it to a minijinja file.\n",
    "\n",
    "**Note:** DSPy does not model the chat completion interface commonly used by language models. This recipe therefore only supports single-turn functions whose inputs all go into the user prompt, whose user schema is flat (only primitive-typed fields), and whose output is text or a flat JSON object.\n",
    "\n",
    "To get started:\n",
    "\n",
    "- Set the `TENSORZERO_CLICKHOUSE_URL` environment variable.\n",
    "- Set the `OPENAI_API_KEY` environment variable.\n",
    "- Update the following parameters to those that apply to your use case.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dspy\n",
    "from clickhouse_connect import get_client\n",
    "from dspy.datasets import Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can swap the client below for any of the ones supported [here](https://dspy.ai/learn/programming/language_models/) in case you want DSPy to use a different language model.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lm_client = dspy.OpenAI(model=\"gpt-4o-mini\")\n",
    "dspy.configure(lm=lm_client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A simple function signature for the `solve_math_problem` function\n",
    "function_signature = \"input -> output\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the ClickHouse client.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert \"TENSORZERO_CLICKHOUSE_URL\" in os.environ, \"TENSORZERO_CLICKHOUSE_URL environment variable not set\"\n",
    "clickhouse_client = get_client(dsn=os.environ[\"TENSORZERO_CLICKHOUSE_URL\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fetch the inferences that were marked as successful by the `correct` metric.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"\"\"\n",
    "SELECT \n",
    "    i.variant_name, \n",
    "    i.input, \n",
    "    i.output, \n",
    "    i.episode_id,\n",
    "    f.value\n",
    "FROM \n",
    "    ChatInference i\n",
    "JOIN \n",
    "    (SELECT\n",
    "        target_id,\n",
    "        value,\n",
    "        ROW_NUMBER() OVER (PARTITION BY target_id ORDER BY timestamp DESC) as rn\n",
    "    FROM BooleanMetricFeedback\n",
    "    WHERE\n",
    "        metric_name = 'correct'\n",
    "        AND value = true\n",
    "    ) f ON i.id = f.target_id and f.rn = 1\n",
    "WHERE \n",
    "    i.function_name = 'solve_math_problem'\n",
    "LIMIT %(max_samples)s\n",
    "\"\"\"\n",
    "\n",
    "params = {\n",
    "    \"max_samples\": 1000,\n",
    "}\n",
    "\n",
    "df = clickhouse_client.query_df(query, params)\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_dspy_compatible_inputs(input_raw: str) -> Optional[Dict[str, str]]:\n",
    "    \"\"\"\n",
    "    Checks that the input of this Inference is in the correct format for DSPy.\n",
    "    Then returns the dictionary of inputs.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        parsed_input = json.loads(input_raw)\n",
    "    except json.JSONDecodeError:\n",
    "        print(f\"Input is not valid JSON: {input_raw}\")\n",
    "        return None\n",
    "    messages = parsed_input.get(\"messages\", None)\n",
    "    if messages is None:\n",
    "        print(f\"Input contains no messages: {input_raw}\")\n",
    "        return None\n",
    "    if len(messages) != 1:\n",
    "        print(f\"Input must contain exactly one message: {input_raw}\")\n",
    "        return None\n",
    "    message = messages[0]\n",
    "    content = message.get(\"content\", None)\n",
    "    if content is None:\n",
    "        print(f\"Input contains no content: {input_raw}\")\n",
    "        return None\n",
    "    if len(content) != 1:\n",
    "        print(f\"Input must contain exactly one content item: {input_raw}\")\n",
    "        return None\n",
    "    content = content[0]\n",
    "    if content.get(\"type\") != \"text\":\n",
    "        print(f\"Input contains non-text content: {input_raw}\")\n",
    "        return None\n",
    "    value = content.get(\"value\", None)\n",
    "    if value is None:\n",
    "        print(f\"Input contains no value: {input_raw}\")\n",
    "        return None\n",
    "    return value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the input column into a Series of dicts (None for rows that are not DSPy-compatible)\n",
    "parsed_inputs = df[\"input\"].apply(parse_dspy_compatible_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_outputs(output_raw: str) -> Optional[str]:\n",
    "    try:\n",
    "        parsed_output = json.loads(output_raw)\n",
    "    except json.JSONDecodeError:\n",
    "        print(f\"Output is not valid JSON: {output_raw}\")\n",
    "        return None\n",
    "    if len(parsed_output) != 1:\n",
    "        print(f\"Output must contain exactly one content block: {output_raw}\")\n",
    "        return None\n",
    "    message = parsed_output[0]\n",
    "    if message.get(\"type\") != \"text\":\n",
    "        print(f\"Output contains non-text content: {output_raw}\")\n",
    "        return None\n",
    "    value = message.get(\"text\", None)\n",
    "    if value is None:\n",
    "        print(f\"Output contains no value: {output_raw}\")\n",
    "        return None\n",
    "    return value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the output column into a Series of strings (None for rows that could not be parsed)\n",
    "parsed_outputs = df[\"output\"].apply(parse_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = pd.concat([parsed_inputs, parsed_outputs], axis=1)\n",
    "all_data = all_data.dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TensorZeroDSPyDataset(Dataset):\n",
    "    def __init__(\n",
    "        self,\n",
    "        df: pd.DataFrame,\n",
    "        dev_fraction: float = 0.1,\n",
    "    ):\n",
    "        # Randomly shuffle the DataFrame\n",
    "        df = df.sample(frac=1, random_state=42).reset_index(drop=True)\n",
    "\n",
    "        # Extract the 'question' string from the 'input' column\n",
    "        df[\"input\"] = df[\"input\"].apply(lambda x: x[\"question\"])\n",
    "\n",
    "        # Calculate the number of samples for train and dev sets\n",
    "        total_samples = len(df)\n",
    "        dev_samples = int(total_samples * dev_fraction)\n",
    "        train_samples = total_samples - dev_samples\n",
    "\n",
    "        # Split the DataFrame\n",
    "        train_df = df.iloc[:train_samples]\n",
    "        dev_df = df.iloc[train_samples:]\n",
    "\n",
    "        # Convert each split into a list of record dicts\n",
    "        self._train = train_df.to_dict(orient=\"records\")\n",
    "        self._dev = dev_df.to_dict(orient=\"records\")\n",
    "        self._test = None\n",
    "        self.train_size = len(self._train)\n",
    "        self.dev_size = len(self._dev)\n",
    "        super().__init__(\n",
    "            train_size=self.train_size,\n",
    "            dev_size=self.dev_size,\n",
    "            test_size=0,\n",
    "        )\n",
    "\n",
    "        print(f\"Train set: {len(self._train)} samples\")\n",
    "        print(f\"Dev set: {len(self._dev)} samples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = TensorZeroDSPyDataset(all_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dspy_function = dspy.Predict(function_signature)\n",
    "\n",
    "\n",
    "class Predictor(dspy.Module):\n",
    "    def __init__(self, signature: dspy.Signature):\n",
    "        super().__init__()\n",
    "        self.prog = dspy.Predict(signature)\n",
    "\n",
    "    def forward(self, **inputs):\n",
    "        return self.prog(**inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can swap the teleprompter with any of the teleprompting classes supported by DSPy [here](https://dspy-docs.vercel.app/docs/building-blocks/optimizers).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dspy.teleprompt import LabeledFewShot\n",
    "\n",
    "teleprompter = LabeledFewShot(k=5)\n",
    "optimized_function = teleprompter.compile(Predictor(function_signature), trainset=dataset.train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dataset.dev[0][\"input\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We run an example inference to get the prompt from the history.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_function(input=\"test_input\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's parse out the prompt from the history.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dspy_prompt = lm_client.history[-1][\"prompt\"]\n",
    "# Strip the actual inference input from the prompt (DSPy does not separate the prompt from the inputs in this history)\n",
    "dspy_prompt = \"---\".join(dspy_prompt.split(\"---\")[:-1])\n",
    "\n",
    "# DSPy does not know the output format for GSM8K, so we add it manually\n",
    "merged_prompt = f\"\"\"\n",
    "You are tasked with solving a math problem. You will be given an open-ended question that should require arithmetic to solve.\n",
    "\n",
    "Feel free to work through the problem step-by-step in your response, but once you have found the solution, please complete your response with:\n",
    "#### your_answer\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "{dspy_prompt}\n",
    "\n",
    "---\n",
    "\n",
    "REMEMBER: End your response with `#### your_answer`, where `your_answer` should be an integer with no other punctuation.\n",
    "\"\"\".strip()\n",
    "\n",
    "print(merged_prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Write the optimized user prompt to a minijinja file and try it out! To evaluate it, go back to the test evaluation cell above, add the new variant name to `VARIANTS`, and rerun it.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
